{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install lightly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchvision\n",
    "from sklearn.cluster import KMeans\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly import loss, models\n",
    "from lightly.models import utils\n",
    "from lightly.models.modules import heads\n",
    "from lightly.models.modules.memory_bank import MemoryBankModule\n",
    "from lightly.transforms.smog_transform import SMoGTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SMoGModel(pl.LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # create a ResNet backbone and remove the classification head\n",
    "        resnet = models.ResNetGenerator(\"resnet-18\")\n",
    "        self.backbone = nn.Sequential(\n",
    "            *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1)\n",
    "        )\n",
    "\n",
    "        # create a model based on ResNet\n",
    "        self.projection_head = heads.SMoGProjectionHead(512, 2048, 128)\n",
    "        self.prediction_head = heads.SMoGPredictionHead(128, 2048, 128)\n",
    "        self.backbone_momentum = copy.deepcopy(self.backbone)\n",
    "        self.projection_head_momentum = copy.deepcopy(self.projection_head)\n",
    "        utils.deactivate_requires_grad(self.backbone_momentum)\n",
    "        utils.deactivate_requires_grad(self.projection_head_momentum)\n",
    "\n",
    "        # smog\n",
    "        self.n_groups = 300\n",
    "        memory_bank_size = 10000\n",
    "        self.memory_bank = MemoryBankModule(size=(memory_bank_size, 128))\n",
    "        # create our loss\n",
    "        group_features = torch.nn.functional.normalize(\n",
    "            torch.rand(self.n_groups, 128), dim=1\n",
    "        ).to(self.device)\n",
    "        self.smog = heads.SMoGPrototypes(group_features=group_features, beta=0.99)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:\n",
    "        features = features.cpu().numpy()\n",
    "        kmeans = KMeans(self.n_groups).fit(features)\n",
    "        clustered = torch.from_numpy(kmeans.cluster_centers_).float()\n",
    "        clustered = torch.nn.functional.normalize(clustered, dim=1)\n",
    "        return clustered\n",
    "\n",
    "    def _reset_group_features(self):\n",
    "        # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n",
    "        features = self.memory_bank.bank\n",
    "        group_features = self._cluster_features(features)\n",
    "        self.smog.set_group_features(group_features)\n",
    "\n",
    "    def _reset_momentum_weights(self):\n",
    "        # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n",
    "        self.backbone_momentum = copy.deepcopy(self.backbone)\n",
    "        self.projection_head_momentum = copy.deepcopy(self.projection_head)\n",
    "        utils.deactivate_requires_grad(self.backbone_momentum)\n",
    "        utils.deactivate_requires_grad(self.projection_head_momentum)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        if self.global_step > 0 and self.global_step % 300 == 0:\n",
    "            # reset group features and weights every 300 iterations\n",
    "            self._reset_group_features()\n",
    "            self._reset_momentum_weights()\n",
    "        else:\n",
    "            # update momentum\n",
    "            utils.update_momentum(self.backbone, self.backbone_momentum, 0.99)\n",
    "            utils.update_momentum(\n",
    "                self.projection_head, self.projection_head_momentum, 0.99\n",
    "            )\n",
    "\n",
    "        (x0, x1) = batch[0]\n",
    "\n",
    "        if batch_idx % 2:\n",
    "            # swap batches every second iteration\n",
    "            x0, x1 = x1, x0\n",
    "\n",
    "        x0_features = self.backbone(x0).flatten(start_dim=1)\n",
    "        x0_encoded = self.projection_head(x0_features)\n",
    "        x0_predicted = self.prediction_head(x0_encoded)\n",
    "        x1_features = self.backbone_momentum(x1).flatten(start_dim=1)\n",
    "        x1_encoded = self.projection_head_momentum(x1_features)\n",
    "\n",
    "        # update group features and get group assignments\n",
    "        assignments = self.smog.assign_groups(x1_encoded)\n",
    "        group_features = self.smog.get_updated_group_features(x0_encoded)\n",
    "        logits = self.smog(x0_predicted, group_features, temperature=0.1)\n",
    "        self.smog.set_group_features(group_features)\n",
    "\n",
    "        loss = self.criterion(logits, assignments)\n",
    "\n",
    "        # use memory bank to periodically reset the group features with k-means\n",
    "        self.memory_bank(x0_encoded, update=True)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        params = (\n",
    "            list(self.backbone.parameters())\n",
    "            + list(self.projection_head.parameters())\n",
    "            + list(self.prediction_head.parameters())\n",
    "        )\n",
    "        optim = torch.optim.SGD(\n",
    "            params,\n",
    "            lr=0.01,\n",
    "            momentum=0.9,\n",
    "            weight_decay=1e-6,\n",
    "        )\n",
    "        return optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SMoGModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = SMoGTransform(\n",
    "    crop_sizes=(32, 32),\n",
    "    crop_counts=(1, 1),\n",
    "    gaussian_blur_probs=(0.0, 0.0),\n",
    "    crop_min_scales=(0.2, 0.2),\n",
    "    crop_max_scales=(1.0, 1.0),\n",
    ")\n",
    "dataset = torchvision.datasets.CIFAR10(\n",
    "    \"datasets/cifar10\", download=True, transform=transform\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n",
    "trainer.fit(model=model, train_dataloaders=dataloader)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
